{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f411b65",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp oauth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0865b3d0",
   "metadata": {},
   "source": [
    "# OAuth\n",
    "> Basic scaffolding for handling OAuth\n",
    "\n",
    "- eval: false\n",
    "- skip_exec: true"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "507cd009",
   "metadata": {},
   "source": [
    "This is not yet thoroughly tested. See the [docs page](https://docs.fastht.ml/explains/oauth.html) for an explanation of how to use this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793722f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fasthtml.common import *\n",
    "from oauthlib.oauth2 import WebApplicationClient\n",
    "from urllib.parse import urlencode, parse_qs, quote, unquote\n",
    "from httpx import get, post\n",
    "import secrets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a078133",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class _AppClient(WebApplicationClient):\n",
    "    \"Base `WebApplicationClient` that also keeps the client secret and redirect URI\"\n",
    "    def __init__(self, client_id, client_secret, redirect_uri, code=None, scope=None, **kwargs):\n",
    "        # The parent class receives the client id, code, scope and any extra keyword arguments\n",
    "        super().__init__(client_id, code=code, scope=scope, **kwargs)\n",
    "        self.client_secret = client_secret\n",
    "        self.redirect_uri = redirect_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d82ea17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class GoogleAppClient(_AppClient):\n",
    "    \"A `WebApplicationClient` for Google oauth2\"\n",
    "    base_url = \"https://accounts.google.com/o/oauth2/v2/auth\"\n",
    "    token_url = \"https://www.googleapis.com/oauth2/v4/token\"\n",
    "    info_url = \"https://www.googleapis.com/oauth2/v3/userinfo\"\n",
    "    id_key = 'sub'\n",
    "\n",
    "    def __init__(self, client_id, client_secret, redirect_uri=None, redirect_uris=None, code=None, scope=None, **kwargs):\n",
    "        # Fall back to the first entry of `redirect_uris` when no single URI is given\n",
    "        if not redirect_uri and redirect_uris: redirect_uri = redirect_uris[0]\n",
    "        # Default scopes: openid plus the userinfo email and profile scopes\n",
    "        if not scope:\n",
    "            userinfo = \"https://www.googleapis.com/auth/userinfo\"\n",
    "            scope = [\"openid\"] + [f\"{userinfo}.{part}\" for part in (\"email\", \"profile\")]\n",
    "        super().__init__(client_id, client_secret, redirect_uri, code=code, scope=scope, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "371ab1f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class GitHubAppClient(_AppClient):\n",
    "    \"A `WebApplicationClient` for GitHub oauth2\"\n",
    "    base_url = \"https://github.com/login/oauth/authorize\"\n",
    "    token_url = \"https://github.com/login/oauth/access_token\"\n",
    "    info_url = \"https://api.github.com/user\"\n",
    "    id_key = 'id'\n",
    "\n",
    "    def __init__(self, client_id, client_secret, redirect_uri, code=None, scope=None, **kwargs):\n",
    "        # An empty or missing scope falls back to \"user\"\n",
    "        super().__init__(client_id, client_secret, redirect_uri, code=code, scope=scope or \"user\", **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e79f996",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class HuggingFaceClient(_AppClient):\n",
    "    \"A `WebApplicationClient` for HuggingFace oauth2\"\n",
    "\n",
    "    base_url = \"https://huggingface.co/oauth/authorize\"\n",
    "    token_url = \"https://huggingface.co/oauth/token\"\n",
    "    info_url = \"https://huggingface.co/oauth/userinfo\"\n",
    "    id_key = 'sub'\n",
    "\n",
    "    def __init__(self, client_id, client_secret, redirect_uri=None, redirect_uris=None, code=None, scope=None, state=None, **kwargs):\n",
    "        # Fall back to the first entry of `redirect_uris` when no single URI is given\n",
    "        if not redirect_uri and redirect_uris: redirect_uri = redirect_uris[0]\n",
    "        scope = scope or [\"openid\",\"profile\"]\n",
    "        # A random URL-safe state is generated when the caller does not supply one\n",
    "        state = state or secrets.token_urlsafe(16)\n",
    "        super().__init__(client_id, client_secret, redirect_uri, code=code, scope=scope, state=state, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f037bd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class DiscordAppClient(_AppClient):\n",
    "    \"A `WebApplicationClient` for Discord oauth2\"\n",
    "    base_url = \"https://discord.com/oauth2/authorize\"\n",
    "    token_url = \"https://discord.com/api/oauth2/token\"\n",
    "    revoke_url = \"https://discord.com/api/oauth2/token/revoke\"\n",
    "    # Needed by `get_info`/`retr_info`/`retr_id`, which read `self.info_url`\n",
    "    info_url = \"https://discord.com/api/users/@me\"\n",
    "    id_key = 'id'\n",
    "\n",
    "    def __init__(self, client_id, client_secret, redirect_uri, is_user=False, perms=0, scope=None, **kwargs):\n",
    "        if not scope: scope=\"applications.commands applications.commands.permissions.update identify\"\n",
    "        # `integration_type` is sent in the login link: 1 when `is_user` is truthy, otherwise 0\n",
    "        self.integration_type = 1 if is_user else 0\n",
    "        # Stored only; the login link below does not currently send `permissions`\n",
    "        self.perms = perms\n",
    "        super().__init__(client_id, client_secret, redirect_uri, scope=scope, **kwargs)\n",
    "\n",
    "    def login_link(self, scope=None):\n",
    "        \"Get a login link for this client, optionally overriding the stored `scope`\"\n",
    "        if not scope: scope = self.scope\n",
    "        # `urlencode` would stringify a list verbatim, so join multiple scopes with spaces\n",
    "        if not isinstance(scope, str): scope = ' '.join(scope)\n",
    "        d = dict(response_type='code', client_id=self.client_id,\n",
    "                 integration_type=self.integration_type, scope=scope,\n",
    "                 redirect_uri=self.redirect_uri) #, permissions=self.perms, prompt='consent')\n",
    "        return f'{self.base_url}?' + urlencode(d)\n",
    "\n",
    "    def parse_response(self, code):\n",
    "        \"Exchange `code` for a token using a form-encoded request with HTTP basic auth\"\n",
    "        headers = {'Content-Type': 'application/x-www-form-urlencoded'}\n",
    "        data = dict(grant_type='authorization_code', code=code, redirect_uri=self.redirect_uri)\n",
    "        r = post(self.token_url, data=data, headers=headers, auth=(self.client_id, self.client_secret))\n",
    "        r.raise_for_status()\n",
    "        self.parse_request_body_response(r.text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990d2310",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def login_link(self:WebApplicationClient, scope=None):\n",
    "    \"Build the login URL for this client, using `scope` or else the client's own scope\"\n",
    "    return self.prepare_request_uri(self.base_url, self.redirect_uri, scope or self.scope)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42ee9991",
   "metadata": {},
   "source": [
    "Generating a login link that sends the user to the OAuth provider is done with `client.login_link()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5039b291",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def login_link_with_state(self:WebApplicationClient, scope=None, state=None):\n",
    "    \"Build the login URL for this client, including a `state` value\"\n",
    "    # Missing arguments fall back to the values stored on the client\n",
    "    use_scope = scope or self.scope\n",
    "    use_state = state or self.state\n",
    "    return self.prepare_request_uri(self.base_url, self.redirect_uri, use_scope, use_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31627bcb",
   "metadata": {},
   "source": [
    "It can sometimes be useful to pass state to the OAuth provider, so that when the user returns you can pick up where they left off. This can be done by using the `login_link_with_state` function with a `state` parameter:\n",
    "\n",
    "TODO: do all providers support this the same way? This is only tested for HF atm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72700129",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "https://huggingface.co/oauth/authorize?response_type=code&client_id=YOUR_CLIENT_ID&redirect_uri=http%3A%2F%2Flocalhost%3A8000%2Fredirect&scope=openid+profile&state=test_state\n"
     ]
    }
   ],
   "source": [
    "# Example redirect URL; it must match the one registered with the OAuth provider\n",
    "redirect_uri = \"http://localhost:8000/redirect\"\n",
    "client = HuggingFaceClient(\"YOUR_CLIENT_ID\",\"YOUR_CLIENT_SECRET\",redirect_uri)\n",
    "print(client.login_link_with_state(state=\"test_state\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "479878a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def parse_response(self:_AppClient, code):\n",
    "    \"Exchange the authorization `code` at `token_url` and parse the server's token response\"\n",
    "    payload = {'code': code, 'redirect_uri': self.redirect_uri, 'client_id': self.client_id,\n",
    "               'client_secret': self.client_secret, 'grant_type': 'authorization_code'}\n",
    "    resp = post(self.token_url, json=payload)\n",
    "    # Raise on an HTTP error status before handing the body to the parser\n",
    "    resp.raise_for_status()\n",
    "    self.parse_request_body_response(resp.text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6967dbb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def get_info(self:_AppClient):\n",
    "    \"Get the info for authenticated user, raising if the provider returns an HTTP error\"\n",
    "    headers = {'Authorization': f'Bearer {self.token[\"access_token\"]}'}\n",
    "    r = get(self.info_url, headers=headers)\n",
    "    # Fail here, as `parse_response` does, instead of returning an error payload as user info\n",
    "    r.raise_for_status()\n",
    "    return r.json()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03702349",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def retr_info(self:_AppClient, code):\n",
    "    \"Run `parse_response` with `code`, then return the result of `get_info`\"\n",
    "    self.parse_response(code)\n",
    "    user_info = self.get_info()\n",
    "    return user_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29f52061",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def retr_id(self:_AppClient, code):\n",
    "    \"Call `retr_info` and return the value stored under this client's `id_key`\"\n",
    "    info = self.retr_info(code)\n",
    "    return info[self.id_key]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d978e813",
   "metadata": {},
   "source": [
    "After logging in via the provider, the user will be redirected back to the supplied redirect URL. The request to this URL will contain a `code` parameter, which is used to get an access token and fetch the user's profile information. See [the explanation here](https://docs.fastht.ml/explains/oauth.html) for a worked example. You can either:\n",
    "\n",
    "- use `client.retr_info(code)` to get all the profile information, or\n",
    "- use `client.retr_id(code)` to get just the user's ID.\n",
    "\n",
    "After either of these calls, you can also access the access token (used to revoke access, for example) with `client.token[\"access_token\"]`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "474e14b4",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d211e8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#|hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f7a90b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
